import json
import logging
import math
import os
from typing import Any, Dict

from rtp_llm.config.gpt_init_model_parameters import GptInitModelParameters
from rtp_llm.model_factory_register import register_model
from rtp_llm.models.base_model import BaseModel
from rtp_llm.models.llama_weight import GemmaWeightInfo, LlamaWeightInfo


def compute_intermediate_size(n, ffn_dim_multiplier=1, multiple_of=256):
    # Meta-style FFN sizing: take 8/3 of the hidden dim, apply the optional
    # multiplier, then round up to the next multiple of `multiple_of`.
    return multiple_of * (
        (int(ffn_dim_multiplier * int(8 * n / 3)) + multiple_of - 1) // multiple_of
    )
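
# Worked example (illustrative, assuming Llama-2-7B-style dims of dim=4096,
# multiple_of=256): int(8 * 4096 / 3) = 10922, rounded up to a multiple of 256,
# so compute_intermediate_size(4096, 1, 256) == 11008.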


class Llama(BaseModel):
    @staticmethod
    def get_mscale(scale: float):
        # YaRN-style attention scaling factor (mscale) used when rope is scaled.
        if scale <= 1:
            return 1.0
        return 0.1 * math.log(scale) + 1.0
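        # Illustrative value: get_mscale(4.0) = 0.1 * ln(4.0) + 1.0 ≈ 1.139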

    @staticmethod
    def get_weight_cls():
        return LlamaWeightInfo

    @classmethod
    def _create_config(cls, ckpt_path: str):
        config = GptInitModelParameters(
            head_num=0,
            size_per_head=0,
            layer_num=0,
            max_seq_len=0,
            vocab_size=0,
            ckpt_path=ckpt_path,
            activation_type="SiGLU",
            norm_type="rmsnorm",
            rotary_embedding_dim=128,
            rotary_embedding_style=1,
            has_post_decoder_layernorm=True,
        )
        # huggingface
        config_path = os.path.join(ckpt_path, "config.json")
        # llama-int8
        param_path = os.path.join(ckpt_path, "params.json")
        if os.path.exists(config_path):
            with open(config_path) as reader:
                content = reader.read()
                content = content.replace("LlamaForCausalLM", "LLaMAForCausalLM")
                config_json = json.loads(content)
            Llama.from_huggingface(config, config_json)
        elif os.path.exists(param_path):
            logging.info("llama not find config.json, use default config")
            with open(param_path) as reader:
                param_json = json.loads(reader.read())
                config_json = param_json
            Llama.from_params(config, param_json)
        else:
            raise Exception("llama parameter from unkown source")

        return config

    @staticmethod
    def from_huggingface(config, config_json: Dict[str, Any]):
        config.head_num = config_json["num_attention_heads"]
        config.head_num_kv = config_json.get("num_key_value_heads", config.head_num)
        config.hidden_size = config_json["hidden_size"]
        config.size_per_head = (
            config_json["hidden_size"] // config_json["num_attention_heads"]
        )
        config.size_per_head = config_json.get("head_dim", config.size_per_head)
        config.layer_num = config_json["num_hidden_layers"]
        config.max_seq_len = config_json.get("max_sequence_length", 2048)
        config.vocab_size = config_json["vocab_size"]
        config.layernorm_eps = config_json.get(
            "rms_norm_eps", config_json.get("layer_norm_eps", 1e-05)
        )
        config.inter_size = config_json["intermediate_size"]
        config.rotary_embedding_base = config_json.get("rope_theta", 10000)
        config.rotary_embedding_dim = config.size_per_head
        config.tie_word_embeddings = config_json.get("tie_word_embeddings", False)
        rope_scaling = config_json.get("rope_scaling")
        if rope_scaling is not None:
            rope_type = rope_scaling.get("type", rope_scaling.get("rope_type"))
            if rope_type == "linear":
                config.rotary_embedding_scale = rope_scaling["factor"]
                config.org_embedding_max_pos = config_json.get(
                    "max_position_embeddings", 2048
                )
            elif rope_type == "dynamic":
                config.rotary_embedding_style = 3
            elif rope_type == "yarn":
                config.rotary_embedding_style = 5
                config.rotary_embedding_scale = rope_scaling["factor"]
                config.rotary_factor1 = rope_scaling.get("beta_slow", 1)
                config.rotary_factor2 = rope_scaling.get("beta_fast", 32)
                config.org_embedding_max_pos = rope_scaling[
                    "original_max_position_embeddings"
                ]
                config.rotary_embedding_mscale = Llama.get_mscale(
                    config.rotary_embedding_scale
                )
            elif rope_type == "llama3":
                config.rotary_embedding_style = 6
                config.rotary_embedding_scale = rope_scaling["factor"]
                config.rotary_factor1 = rope_scaling["low_freq_factor"]
                config.rotary_factor2 = rope_scaling["high_freq_factor"]
                config.org_embedding_max_pos = rope_scaling[
                    "original_max_position_embeddings"
                ]
            else:
                raise Exception(f"unsupport rope_scaling {rope_scaling}")
        # config.activation_type = config_json.get("hidden_act", config.activation_type)
        config.special_tokens.bos_token_id = config_json.get("bos_token_id", -1)
        eos_token_id = config_json.get("eos_token_id", 0)
        # the openai endpoint gets the correct eos token id list from the tokenizer
        if isinstance(eos_token_id, list):
            config.special_tokens.eos_token_id = eos_token_id[0]
        else:
            config.special_tokens.eos_token_id = eos_token_id
        config.use_logn_attn = config_json.get("use_logn_attn", False)
        config.config_dtype = config_json.get("torch_dtype", None)
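
    # Illustration (hypothetical config.json rope_scaling snippet, not taken from
    # a real checkpoint):
    #   {"rope_type": "llama3", "factor": 8.0, "low_freq_factor": 1.0,
    #    "high_freq_factor": 4.0, "original_max_position_embeddings": 8192}
    # maps to rotary_embedding_style = 6 with those factors copied into the config.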

    @staticmethod
    def from_params(config: GptInitModelParameters, params_json: Dict[str, Any]):
        config.head_num = params_json["n_heads"]
        config.head_num_kv = params_json.get("n_kv_heads", config.head_num)
        config.size_per_head = params_json["dim"] // params_json["n_heads"]
        config.layer_num = params_json["n_layers"]
        config.max_seq_len = 2048
        config.vocab_size = 32000
        config.layernorm_eps = params_json["norm_eps"]
        config.inter_size = compute_intermediate_size(
            params_json["dim"],
            params_json.get("ffn_dim_multiplier", 1),
            params_json["multiple_of"],
        )
        config.special_tokens.bos_token_id = 1
        config.special_tokens.eos_token_id = 2
        config.rotary_embedding_dim = config.size_per_head
        config.tie_word_embeddings = params_json.get("tie_word_embeddings", False)
        config.config_dtype = params_json.get("torch_dtype", None)
        return config
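
    # Illustration (hypothetical params.json, for example only):
    #   {"dim": 4096, "n_heads": 32, "n_layers": 32,
    #    "norm_eps": 1e-5, "multiple_of": 256}
    # yields head_num=32, size_per_head=128, layer_num=32, inter_size=11008.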


class Baichuan(Llama):
    @classmethod
    def _create_config(cls, ckpt_path: str):
        config = Llama._create_config(ckpt_path)
        if config.layer_num == 40:  # Baichuan-13B uses ALiBi, not rope
            config.rotary_embedding_style = 0
            config.rotary_embedding_dim = 0
            config.use_attention_linear_bias = True
        config.special_tokens.bos_token_id = -1
        return config


class Baichuan2(Baichuan):
    @classmethod
    def _create_config(cls, ckpt_path: str):
        config = Baichuan._create_config(ckpt_path)
        config.normalize_lm_head_weight = True
        return config


class Gemma(Llama):
    def __init__(self, config: GptInitModelParameters):
        if config.py_env_configs.fmha_config.enable_open_source_fmha:
            logging.warning(
                "open-source fmha does not support head dim 256, so it is disabled for the gemma model"
            )
            os.environ["ENABLE_OPENSOURCE_FMHA"] = "OFF"
        super().__init__(config)

    @staticmethod
    def get_weight_cls():
        return GemmaWeightInfo

    @classmethod
    def _create_config(cls, ckpt_path: str):
        config = Llama._create_config(ckpt_path)
        config.has_post_decoder_layernorm = True
        config.input_embedding_scalar = config.hidden_size**0.5
        config.rotary_embedding_dim = config.size_per_head
        config.activation_type = "gated-gelu"
        return config
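
    # Gemma scales input embeddings by sqrt(hidden_size); with a hypothetical
    # hidden_size of 3072 that gives input_embedding_scalar ≈ 55.4.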


class Cohere(Llama):
    @classmethod
    def _create_config(cls, ckpt_path: str):
        config = Llama._create_config(ckpt_path)
        config.rotary_embedding_style = 0
        config.norm_type = "layernorm"
        config.qk_norm = True
        return config


register_model("internlm", Llama, ["InternLMForCausalLM"])
register_model("internlm2", Llama, ["InternLM2ForCausalLM"])
register_model("llama", Llama, ["LlamaForCausalLM", "YiForCausalLM"])
register_model("xverse", Llama, ["XverseForCausalLM"])
register_model("aquila", Llama, ["AquilaModel"])
register_model("mistral", Llama, ["MistralForCausalLM"])
register_model("baichuan", Baichuan, ["BaichuanForCausalLM"])
register_model("baichuan2", Baichuan2)
register_model("gemma", Gemma, ["GemmaForCausalLM"])
register_model("cohere", Cohere, ["CohereForCausalLM"])
